# NSW CATE ANALYSIS - SAMPLE SIZE ANALYSIS WITH THEORETICAL BOUNDS

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

def theoretical_bound(m, beta, N, delta, OPT):
    """Compute the theoretical bound ``(1 - (N*ln(2N/delta)/m)**beta) * OPT``.

    Args:
        m: Sample size.
        beta: Exponent applied to the error term (the plots use 0.5 and 1.0).
        N: Number of groups.
        delta: Confidence parameter inside the logarithm.
        OPT: Optimal allocation value that the bound is a fraction of.

    Returns:
        The bound as a float. Returns ``0.0`` when the bound is vacuous:
        either the error term reaches 1, or ``m <= 0``. Without that guard,
        ``m == 0`` only works through a numpy division-by-zero, and a
        negative ``m`` makes the error term negative, so a fractional
        ``beta`` yields NaN.
    """
    if m <= 0:
        return 0.0
    term = N * np.log(2 * N / delta) / m
    if term >= 1:
        return 0.0  # Bound becomes meaningless
    return (1 - term**beta) * OPT

class TeeOutput:
    """Write output to both the console and a log file simultaneously.

    Install it manually with ``sys.stdout = TeeOutput(path)``, or use it as a
    context manager. The context manager installs it as ``sys.stdout`` on
    entry. On exit it restores the original stream and closes the log.
    """

    def __init__(self, filename):
        """Remember the current ``sys.stdout`` and open *filename* for writing.

        The log is opened as UTF-8, so characters this script prints (for
        example the arrow in the CATE-normalisation message) can be written
        whatever the platform's default encoding is.
        """
        self.terminal = sys.stdout
        self.log = open(filename, 'w', encoding='utf-8')

    def write(self, message):
        """Write *message* to the terminal and the log, flushing the log."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file.

        If this tee is currently installed as ``sys.stdout``, the original
        stream is restored first. Otherwise later prints would hit a closed
        file and raise.
        """
        if sys.stdout is self:
            sys.stdout = self.terminal
        self.log.close()

    def __enter__(self):
        """Install this tee as ``sys.stdout`` and return it."""
        sys.stdout = self
        return self

    def __exit__(self, exc_type, exc, tb):
        """Restore the original stdout and close the log; never suppress errors."""
        self.close()
        return False

class NSWSampleSizeAnalyzer:
    """NSW CATE allocation with sample size analysis and theoretical bounds.

    Each ``create_*_groups`` method partitions the processed NSW frame into
    candidate groups. It keeps only the groups whose treatment/control split
    passes ``_ensure_balance_and_compute_cate`` and attaches a
    difference-in-means CATE to each one. The remaining methods:

    - normalise those CATEs,
    - simulate Bernoulli sampling from them,
    - plot the realised top-K allocation value against reference curves.

    None of the group builders modify the DataFrame they are given.
    """

    # Covariates used by the model-based grouping methods (causal forest and
    # propensity score). Only the columns present in the frame are used.
    MODEL_FEATURES = ['age', 'education', 'black', 'hispanic', 'married', 'nodegree', 're75']

    def __init__(self, random_seed=42):
        """Store *random_seed* and seed NumPy's global RNG with it."""
        self.random_seed = random_seed
        np.random.seed(random_seed)
        print(f"NSW Sample Size Analyzer initialized with seed {random_seed}")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _assign_bins(values, edges):
        """Map each value to a 0-based bin index given ``len(edges)`` bin edges.

        ``np.digitize`` (right=False) returns ``len(edges)`` for values equal
        to the last edge. Because the edges are quantiles, the last edge is
        the maximum, so those values would get index ``len(edges) - 1``, a bin
        no caller iterates over. Clamping folds them into the top bin, so the
        largest observations are no longer silently dropped.
        """
        n_bins = len(edges) - 1
        return np.minimum(np.digitize(values, edges) - 1, n_bins - 1)

    @staticmethod
    def _budget_to_k(budget_pct, n_groups):
        """Number of groups K that a budget fraction selects (at least 1)."""
        return max(1, int(budget_pct * n_groups))

    def _prepare_covariates(self, df):
        """Return the available ``MODEL_FEATURES`` columns with NaNs imputed.

        Any numeric column (including float32 columns read from Stata) is
        imputed with its median. Other columns use their mode, or 0 when no
        mode exists. The result may have zero columns if none of the features
        are present.
        """
        available_features = [col for col in self.MODEL_FEATURES if col in df.columns]
        X = df[available_features].copy()

        for col in X.columns:
            column = X[col]
            if not column.isna().any():
                continue
            if pd.api.types.is_numeric_dtype(column):
                X[col] = column.fillna(column.median())
            else:
                modes = column.mode()
                X[col] = column.fillna(modes.iloc[0] if len(modes) > 0 else 0)

        return X

    # ------------------------------------------------------------------
    # Data processing
    # ------------------------------------------------------------------

    def process_nsw_data(self, df, outcome_col='re78', treatment_col='treat'):
        """Return a cleaned copy of *df* with standardised analysis columns.

        Adds three columns:

        - ``treatment`` (copied from *treatment_col*),
        - ``outcome`` (copied from *outcome_col*),
        - ``baseline_earnings`` (``re75`` if present, otherwise 0).

        It then drops rows missing outcome or treatment. The input frame is
        not modified.

        Raises:
            ValueError: If *treatment_col* or *outcome_col* is missing.
        """
        print(f"Processing NSW data with {len(df)} observations")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]

        if 're75' in df_processed.columns:
            df_processed['baseline_earnings'] = df_processed['re75']
        else:
            df_processed['baseline_earnings'] = 0  # Default if no baseline

        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} individuals")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (1978 earnings) statistics: mean=${df_processed['outcome'].mean():.0f}, std=${df_processed['outcome'].std():.0f}")

        # baseline_earnings is always assigned above, so its stats are always shown.
        print(f"Baseline (1975 earnings) stats: mean=${df_processed['baseline_earnings'].mean():.0f}, std=${df_processed['baseline_earnings'].std():.0f}")

        return df_processed

    # ------------------------------------------------------------------
    # Group builders
    # ------------------------------------------------------------------

    def create_demographics_groups(self, df, min_size=6):
        """Create one group per observed combination of demographic indicators.

        Uses at most the first three of ``black``, ``hispanic``, ``married``
        and ``nodegree`` that are present. Combinations smaller than
        *min_size* are discarded before the balance filter.
        """
        print("Creating NSW demographics groups")

        demo_features = ['black', 'hispanic', 'married', 'nodegree']
        available_features = [col for col in demo_features if col in df.columns]

        if not available_features:
            print("No demographic variables found")
            return []

        # Limit to the first 3 features to avoid too many combinations. Do
        # this before logging, so the message lists the features actually used.
        available_features = available_features[:3]
        print(f"Using demographic features: {available_features}")

        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} individuals")

        if len(df_clean) == 0:
            return []

        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        groups = []
        for _, combo in unique_combinations.iterrows():
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                # NaN never compares equal, so rows with missing features drop out.
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': "_".join(combo_description),
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Create up to *n_groups* groups from age quantile brackets.

        Missing ages are imputed with the median. Tied quantile edges can
        leave some brackets empty, so fewer than *n_groups* groups may result.
        """
        print(f"Creating age groups (target: {n_groups})")

        if 'age' not in df.columns:
            print("No age variable found")
            return []

        age = df['age'].fillna(df['age'].median())

        percentiles = np.linspace(0, 100, n_groups + 1)
        cuts = np.percentile(age, percentiles)
        bins = self._assign_bins(age, cuts)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_group_{i}',
                    'indices': indices,
                    'type': 'age'
                })

        print(f"Created {len(groups)} age groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_education_groups(self, df, min_size=6):
        """Create one group per distinct (non-missing) education level."""
        print("Creating education groups")

        if 'education' not in df.columns:
            print("No education variable found")
            return []

        groups = []
        for education_level in df['education'].dropna().unique():
            indices = df[df['education'] == education_level].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'education_{education_level}_years',
                    'indices': indices,
                    'type': 'education'
                })

        print(f"Created {len(groups)} education groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_baseline_earnings_groups(self, df, n_groups=30, min_size=6):
        """Create groups based on 1975 baseline earnings.

        If more than 30% of people have zero earnings, they form one group.
        The positive earners are then split into ``n_groups - 1`` quantile
        brackets. Otherwise all earnings are split into *n_groups* quantile
        brackets.
        """
        print(f"Creating baseline earnings groups (target: {n_groups})")

        if 'baseline_earnings' not in df.columns or df['baseline_earnings'].sum() == 0:
            print("No baseline earnings data available")
            return []

        earnings = df['baseline_earnings'].fillna(0)  # Fill NaN with 0 for unemployed

        groups = []
        if (earnings == 0).mean() > 0.3:
            zero_earners = df.index[earnings == 0].tolist()
            if len(zero_earners) >= min_size:
                groups.append({
                    'id': 'zero_earnings_1975',
                    'indices': zero_earners,
                    'type': 'baseline_earnings'
                })

            positive_earnings = earnings[earnings > 0]
            if len(positive_earnings) > 0:
                percentiles = np.linspace(0, 100, n_groups)
                cuts = np.percentile(positive_earnings, percentiles)

                for i in range(len(cuts) - 1):
                    # Brackets are (lower, upper]. The first is closed on the
                    # left, otherwise the smallest positive earner (== cuts[0])
                    # would belong to no bracket.
                    above_lower = earnings >= cuts[i] if i == 0 else earnings > cuts[i]
                    mask = above_lower & (earnings <= cuts[i + 1])
                    indices = df.index[mask].tolist()
                    if len(indices) >= min_size:
                        groups.append({
                            'id': f'earnings_1975_bracket_{i}',
                            'indices': indices,
                            'type': 'baseline_earnings'
                        })
        else:
            percentiles = np.linspace(0, 100, n_groups + 1)
            cuts = np.percentile(earnings, percentiles)
            bins = self._assign_bins(earnings, cuts)

            for i in range(n_groups):
                indices = df.index[bins == i].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'earnings_1975_bracket_{i}',
                        'indices': indices,
                        'type': 'baseline_earnings'
                    })

        print(f"Created {len(groups)} baseline earnings groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_race_ethnicity_groups(self, df, min_size=6):
        """Create groups for the race/ethnicity categories black, hispanic and white_other.

        ``black == 1`` takes precedence over ``hispanic == 1``. The category
        labels are kept in a local Series, so *df* is not modified.
        """
        print("Creating race/ethnicity groups")

        has_black = 'black' in df.columns
        has_hispanic = 'hispanic' in df.columns
        if not (has_black or has_hispanic):
            print("No race/ethnicity variables found")
            return []

        no_one = pd.Series(False, index=df.index)
        is_black = df['black'].eq(1) if has_black else no_one
        is_hispanic = df['hispanic'].eq(1) if has_hispanic else no_one
        race_ethnicity = pd.Series(
            np.select(
                [is_black.to_numpy(dtype=bool, na_value=False),
                 is_hispanic.to_numpy(dtype=bool, na_value=False)],
                ['black', 'hispanic'],
                default='white_other',
            ),
            index=df.index,
        )

        groups = []
        for race_eth in race_ethnicity.unique():
            indices = race_ethnicity.index[race_ethnicity == race_eth].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'race_{race_eth}',
                    'indices': indices,
                    'type': 'race_ethnicity'
                })

        print(f"Created {len(groups)} race/ethnicity groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_employment_status_groups(self, df, min_size=6):
        """Create groups from 1975 employment status derived from baseline earnings.

        The statuses are: unemployed (0 or missing), low (< 5000) and higher
        earnings. Statuses are kept in a local Series, so *df* is not modified.
        """
        print("Creating employment status groups")

        if 'baseline_earnings' not in df.columns:
            print("No baseline earnings data for employment status")
            return []

        def get_employment_status(earnings):
            if pd.isna(earnings) or earnings == 0:
                return 'unemployed_1975'
            elif earnings < 5000:  # Low earnings threshold for 1975
                return 'low_earnings_1975'
            else:
                return 'higher_earnings_1975'

        employment_status = df['baseline_earnings'].apply(get_employment_status)

        groups = []
        for status in employment_status.unique():
            indices = employment_status.index[employment_status == status].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'employment_{status}',
                    'indices': indices,
                    'type': 'employment_status'
                })

        print(f"Created {len(groups)} employment status groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """Cluster individuals on covariates plus a T-learner CATE prediction.

        Two random forests are fit, one on treated and one on control
        outcomes. Their prediction difference is appended to the standardised
        covariates, and the result is clustered with KMeans.
        """
        print(f"Creating causal forest groups (target: {n_groups})")

        X = self._prepare_covariates(df)
        if X.shape[1] == 0:
            print("No features available for causal forest")
            return []

        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() < 5 or control_mask.sum() < 5:
            print("Not enough treated or control observations for causal forest")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        # KMeans requires at least as many samples as clusters.
        n_clusters = min(n_groups, len(cluster_features))
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_clusters):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'causal_forest_{i}',
                    'indices': indices,
                    'type': 'causal_forest'
                })

        print(f"Created {len(groups)} causal forest groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Create up to *n_groups* strata from cross-validated propensity-score quantiles."""
        print(f"Creating propensity score groups (target: {n_groups})")

        X = self._prepare_covariates(df)
        if X.shape[1] == 0:
            print("No features available for propensity scoring")
            return []

        try:
            prop_scores = cross_val_predict(
                LogisticRegression(random_state=self.random_seed, max_iter=1000),
                X, df['treatment'], method='predict_proba', cv=5
            )[:, 1]
        except Exception as e:
            print(f"Error computing propensity scores: {e}")
            return []

        quantiles = np.linspace(0, 1, n_groups + 1)
        bins = self._assign_bins(prop_scores, np.quantile(prop_scores, quantiles))

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Keep balanced groups and attach a difference-in-means CATE.

        A group is kept when its treated share is in [0.15, 0.85] and it has
        at least 3 treated and 3 control members. Each kept group is returned
        as a new dict with size, treatment counts and ``cate`` (mean treated
        outcome minus mean control outcome).
        """
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    # ------------------------------------------------------------------
    # Simulation
    # ------------------------------------------------------------------

    def normalize_cates(self, groups):
        """Min-max scale each group's ``cate`` into ``normalized_cate`` in [0, 1].

        Groups are modified in place and returned. If all CATEs are equal,
        every group gets 0.5. An empty list is returned unchanged.
        """
        if not groups:
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / (max_cate - min_cate)
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.0f}, {max_cate:.0f}] → [0, 1]")
        return groups

    def simulate_sampling_trial(self, groups, sample_size, trial_seed):
        """Draw *sample_size* Bernoulli samples and return per-group mean estimates.

        Each draw picks a group uniformly at random, then samples
        Bernoulli(``normalized_cate``). The global RNG is reseeded with
        ``random_seed + trial_seed``, so a given trial is reproducible.

        Returns:
            ``(tau_estimates, sample_counts)``. Groups never sampled have an
            estimate of 0.
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        tau_estimates = np.zeros(n_groups)
        sample_counts = np.zeros(n_groups)

        for _ in range(sample_size):
            group_idx = np.random.randint(n_groups)
            sample = np.random.binomial(1, tau_true[group_idx])

            sample_counts[group_idx] += 1
            count = sample_counts[group_idx]
            # Incremental mean. Because estimates start at 0, the first sample
            # gives ((1 - 1) * 0 + sample) / 1 == sample, so no special case.
            tau_estimates[group_idx] = ((count - 1) * tau_estimates[group_idx] + sample) / count

        return tau_estimates, sample_counts

    def analyze_sample_size_performance(self, groups, sample_sizes, budget_percentages, n_trials=50):
        """Mean and std of the realised top-K value for each budget and sample size.

        For each trial, the K groups with the highest estimated tau are
        selected, and the true normalised tau of that selection is summed.

        Returns:
            ``(results, optimal_values)``.

            - ``results[bp]`` holds the lists ``sample_sizes``, ``values``
              and ``stds``.
            - ``optimal_values[bp]`` is the oracle top-K sum.
        """
        print(f"Analyzing sample size performance with {len(groups)} groups")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])

        budgets = [self._budget_to_k(p, n_groups) for p in budget_percentages]
        print(f"Budget percentages {budget_percentages} → K values {budgets}")

        true_order = np.argsort(tau_true)
        optimal_values = {
            bp: np.sum(tau_true[true_order[-K:]])
            for bp, K in zip(budget_percentages, budgets)
        }

        results = {bp: {'sample_sizes': [], 'values': [], 'stds': []} for bp in budget_percentages}

        for sample_size in sample_sizes:
            print(f"  Sample size {sample_size}...")

            budget_trial_values = {bp: [] for bp in budget_percentages}

            for trial in range(n_trials):
                tau_estimates, _ = self.simulate_sampling_trial(groups, sample_size, trial)
                estimated_order = np.argsort(tau_estimates)

                for bp, K in zip(budget_percentages, budgets):
                    # Select the top K by estimate, but score them with the true tau.
                    selected_indices = estimated_order[-K:]
                    budget_trial_values[bp].append(np.sum(tau_true[selected_indices]))

            for bp in budget_percentages:
                results[bp]['sample_sizes'].append(sample_size)
                results[bp]['values'].append(np.mean(budget_trial_values[bp]))
                results[bp]['stds'].append(np.std(budget_trial_values[bp]))

        return results, optimal_values

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    def plot_sample_size_analysis(self, results, optimal_values, method_name, budget_percentages, n_groups):
        """Plot one panel per budget: empirical value vs. theoretical reference curves.

        Panels are laid out three per row; six budgets give the original 2x3
        grid. Values are normalised by the oracle optimum. The figure is saved
        as a PDF, shown, and then closed, so repeated calls do not accumulate
        open figures.
        """
        n_panels = len(budget_percentages)
        ncols = max(1, min(3, n_panels))
        nrows = max(1, -(-n_panels // ncols))  # ceiling division
        fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
        axes = axes.flatten()
        for unused_ax in axes[n_panels:]:
            unused_ax.set_visible(False)

        delta = 0.05  # confidence parameter for the theoretical bounds

        print(f"\nPlotting {method_name} (N={n_groups})")
        print("="*60)

        for ax, bp in zip(axes, budget_percentages):
            sample_sizes = results[bp]['sample_sizes']
            optimal_val = optimal_values[bp]

            values_norm = np.array(results[bp]['values']) / optimal_val
            stds_norm = np.array(results[bp]['stds']) / optimal_val

            ax.errorbar(sample_sizes, values_norm, yerr=stds_norm,
                        marker='o', capsize=5, capthick=3, linewidth=6, markersize=8,
                        label='Empirical data', color='blue', alpha=0.8)

            ax.axhline(y=1.0, color='black', linestyle=':', linewidth=2,
                       label='Optimal (1.0)', alpha=0.8)

            m_smooth = np.linspace(min(sample_sizes), max(sample_sizes), 200)
            ref_curve_05 = [theoretical_bound(m, 0.5, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]
            ref_curve_10 = [theoretical_bound(m, 1.0, n_groups, delta, optimal_val) / optimal_val for m in m_smooth]

            ax.plot(m_smooth, ref_curve_05, 'red', linestyle=(0, (3, 2)), linewidth=6,
                    label='FullCATE', alpha=0.8)
            ax.plot(m_smooth, ref_curve_10, 'green', linestyle=(0, (3, 1, 1, 1)), linewidth=6,
                    label='ALLOC', alpha=0.8)

            ax.set_xlabel('Sample size', fontsize=23)
            ax.set_ylabel('Normalized allocation value', fontsize=23)
            ax.set_title(f'Budget = {bp*100:.0f}% (K={self._budget_to_k(bp, n_groups)})',
                         fontsize=24, fontweight='bold')

            ax.legend(fontsize=21, framealpha=0.9)
            ax.grid(True, alpha=0.4, linewidth=1)
            ax.tick_params(axis='both', which='major', labelsize=16, width=1.5, length=5)
            ax.set_ylim(0.2, 1.05)  # slightly above the optimum of 1.0

            for spine in ax.spines.values():
                spine.set_linewidth(1.5)

        fig.suptitle(f'{method_name} (N={n_groups})', fontsize=24, fontweight='bold')
        fig.tight_layout()

        clean_name = method_name.replace(' ', '_').replace('(', '').replace(')', '').replace('-', '_')
        pdf_filename = f"{clean_name}_N{n_groups}_sample_size_analysis.pdf"
        fig.savefig(pdf_filename, format='pdf', dpi=300, bbox_inches='tight')
        print(f"Saved plot as: {pdf_filename}")

        plt.show()
        plt.close(fig)

        print(f"Plot complete for {method_name}")


def run_nsw_sample_size_analysis(df_nsw, sample_size_range=None, budget_percentages=None, n_trials=50,
                                outcome_col='re78', treatment_col='treat', min_groups=10):
    """Run sample size analysis on NSW dataset with theoretical bounds.

    Each grouping method runs independently:

    1. A fresh ``NSWSampleSizeAnalyzer`` is created, which reseeds the RNG.
    2. The raw data are re-processed and groups are built.
    3. Methods producing fewer than *min_groups* balanced groups are skipped.
    4. CATEs are normalised, the sampling simulation runs, and plots are made.

    A failure in one method is reported with its traceback and the remaining
    methods still run. The traceback is printed on stdout, so a
    ``TeeOutput`` log captures it.

    Args:
        df_nsw: Raw NSW DataFrame.
        sample_size_range: Sample sizes to simulate; None uses the defaults below.
        budget_percentages: Budget fractions; None uses the defaults below.
        n_trials: Simulation trials per sample size.
        outcome_col: Name of the outcome column in *df_nsw*.
        treatment_col: Name of the treatment column in *df_nsw*.
        min_groups: Minimum number of balanced groups a method must produce.

    Returns:
        A dict mapping each completed method's name to a dict with keys
        ``'results'``, ``'optimal_values'`` and ``'n_groups'``.
    """
    import traceback

    if sample_size_range is None:
        sample_size_range = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]

    if budget_percentages is None:
        budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    print("NSW SAMPLE SIZE ANALYSIS - EMPIRICAL VS THEORETICAL BOUNDS")
    print(f"Sample sizes: {sample_size_range}")
    print(f"Budget percentages: {budget_percentages}")
    print(f"Trials per sample size: {n_trials}")
    print("="*80)

    methods = [
        ('Demographics', lambda analyzer, df: analyzer.create_demographics_groups(df, min_size=6)),
        ('Race/Ethnicity', lambda analyzer, df: analyzer.create_race_ethnicity_groups(df, min_size=6)),
        ('Age Groups', lambda analyzer, df: analyzer.create_age_groups(df, n_groups=30, min_size=6)),
        ('Education', lambda analyzer, df: analyzer.create_education_groups(df, min_size=6)),
        ('Baseline Earnings', lambda analyzer, df: analyzer.create_baseline_earnings_groups(df, n_groups=30, min_size=6)),
        ('Employment Status', lambda analyzer, df: analyzer.create_employment_status_groups(df, min_size=6)),
        ('Causal Forest', lambda analyzer, df: analyzer.create_causal_forest_groups(df, n_groups=30, min_size=6)),
        ('Propensity Score', lambda analyzer, df: analyzer.create_propensity_groups(df, n_groups=50, min_size=6))
    ]

    all_results = {}

    for method_name, method_func in methods:
        print(f"\n{'='*80}")
        print(f"ANALYZING NSW METHOD: {method_name}")
        print("="*80)

        try:
            analyzer = NSWSampleSizeAnalyzer()
            df_processed = analyzer.process_nsw_data(df_nsw, outcome_col=outcome_col, treatment_col=treatment_col)

            groups = method_func(analyzer, df_processed)

            if len(groups) < min_groups:
                print(f"Too few groups ({len(groups)}) for {method_name} - skipping")
                continue

            groups = analyzer.normalize_cates(groups)

            results, optimal_values = analyzer.analyze_sample_size_performance(
                groups, sample_size_range, budget_percentages, n_trials
            )

            all_results[method_name] = {
                'results': results,
                'optimal_values': optimal_values,
                'n_groups': len(groups)
            }

            print(f"Creating plots for {method_name}...")
            analyzer.plot_sample_size_analysis(
                results, optimal_values, method_name, budget_percentages, len(groups)
            )

            print(f"\nSummary for {method_name}:")
            print(f"Number of groups: {len(groups)}")
            print("Optimal values by budget:")
            for bp in budget_percentages:
                print(f"  {bp*100:.0f}%: {optimal_values[bp]:.3f}")

        except Exception as e:
            # Deliberate best effort: report where this method failed and
            # carry on with the remaining methods.
            print(f"Error with {method_name}: {e}")
            traceback.print_exc(file=sys.stdout)
            continue

    return all_results

# Example usage
if __name__ == "__main__":
    # Load the NSW dataset. The path defaults to ./nsw.dta and can be
    # overridden with the first command-line argument.
    data_path = sys.argv[1] if len(sys.argv) > 1 else 'nsw.dta'
    df_nsw = pd.read_stata(data_path)

    # Run sample size analysis with theoretical bounds
    sample_sizes = [100, 250, 500, 750, 1000, 1200, 1500, 2000, 5000, 10000, 20000]
    budget_percentages = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]

    results = run_nsw_sample_size_analysis(
        df_nsw,
        sample_size_range=sample_sizes,
        budget_percentages=budget_percentages,
        n_trials=50
    )